package com.haizhi.sparksql.demo;

import org.apache.spark.SparkConf;
import org.apache.spark.sql.*;
import org.apache.spark.sql.expressions.Aggregator;

import java.io.Serializable;

import static org.apache.spark.sql.functions.udaf;

public class UDAF_ {
    /**
     * Demo entry point: registers {@link MyAvg} as the SQL function {@code avgAge} and
     * shows the result of applying it to the {@code age} column of the JSON input
     * {@code sparkSQL/input/user.json}.
     *
     * @param args command-line arguments; not used
     */
    public static void main(String[] args) {
        // 1. Build the Spark configuration.
        SparkConf conf = new SparkConf().setMaster("local[*]").setAppName("sparkCore");

        // 2. Create the SparkSession. try-with-resources guarantees the session is closed
        //    even when reading the input or running the query throws.
        try (SparkSession spark = SparkSession.builder().config(conf).getOrCreate()) {
            // 3. Expose the JSON input as the temporary view "user".
            spark.read().json("sparkSQL/input/user.json").createOrReplaceTempView("user");

            // Registration relies on: import static org.apache.spark.sql.functions.udaf;
            spark.udf().register("avgAge", udaf(new MyAvg(), Encoders.LONG()));

            spark.sql("select avgAge(age) newAge from user").show();
        }
        // 4. The SparkSession has been closed by the try-with-resources statement above.
    }

    /**
     * Mutable aggregation buffer holding a running sum and the number of values
     * that contributed to it.
     */
    public static class Buffer implements Serializable {
        private static final long serialVersionUID = 1L;

        // Both fields start at zero so that a buffer created through the no-arg
        // constructor can be used in arithmetic without a NullPointerException
        // from auto-unboxing.
        private Long sum = 0L;
        private Long count = 0L;

        /** Creates an empty buffer whose sum and count are both zero. */
        public Buffer() {
        }

        /**
         * Creates a buffer with the given state.
         *
         * @param sum   initial running sum
         * @param count initial number of accumulated values
         */
        public Buffer(Long sum, Long count) {
            this.sum = sum;
            this.count = count;
        }

        /** @return the running sum */
        public Long getSum() {
            return sum;
        }

        /** @param sum the new running sum */
        public void setSum(Long sum) {
            this.sum = sum;
        }

        /** @return the number of accumulated values */
        public Long getCount() {
            return count;
        }

        /** @param count the new number of accumulated values */
        public void setCount(Long count) {
            this.count = count;
        }
    }

    /**
     * Typed aggregator that computes the arithmetic mean of {@code Long} inputs,
     * using a {@link Buffer} (running sum and count) as intermediate state and
     * producing a {@code Double}.
     *
     * <p>Null inputs are ignored, and the result is {@code null} when no non-null
     * value was aggregated, mirroring the behavior of SQL {@code AVG}.
     */
    public static class MyAvg extends Aggregator<Long,Buffer,Double> {

        /**
         * @return a fresh buffer with sum and count both zero
         */
        @Override
        public Buffer zero() {
            return new Buffer(0L, 0L);
        }

        /**
         * Folds one input value into the buffer.
         *
         * @param b the buffer to update; it is mutated and returned
         * @param a the input value; may be {@code null}, in which case it is skipped
         * @return {@code b}
         */
        @Override
        public Buffer reduce(Buffer b, Long a) {
            // Without this guard a null input would throw NullPointerException
            // when it is auto-unboxed for the addition below.
            if (a == null) {
                return b;
            }
            b.setSum(b.getSum() + a);
            b.setCount(b.getCount() + 1);
            return b;
        }

        /**
         * Combines two partial buffers.
         *
         * @param b1 the buffer that receives the combined state; it is mutated and returned
         * @param b2 the buffer whose sum and count are added to {@code b1}
         * @return {@code b1}
         */
        @Override
        public Buffer merge(Buffer b1, Buffer b2) {
            b1.setSum(b1.getSum() + b2.getSum());
            b1.setCount(b1.getCount() + b2.getCount());
            return b1;
        }

        /**
         * Computes the final average from the accumulated state.
         *
         * @param reduction the fully merged buffer
         * @return sum divided by count, or {@code null} when the count is zero
         */
        @Override
        public Double finish(Buffer reduction) {
            long count = reduction.getCount();
            if (count == 0L) {
                // Dividing by a zero count would yield NaN; report "no value" instead.
                return null;
            }
            return reduction.getSum().doubleValue() / count;
        }

        /**
         * @return the encoder used for the intermediate {@link Buffer}
         */
        @Override
        public Encoder<Buffer> bufferEncoder() {
            // Kryo serialization can be used here as an optimization.
            return Encoders.kryo(Buffer.class);
        }

        /**
         * @return the encoder used for the {@code Double} result
         */
        @Override
        public Encoder<Double> outputEncoder() {
            return Encoders.DOUBLE();
        }
    }

}
